131. 分割回文串

链接:https://leetcode-cn.com/problems/palindrome-partitioning/

题目描述

给定一个字符串 s,将 s 分割成一些子串,使每个子串都是回文串。
返回 s 所有可能的分割方案。

示例:

1
2
3
4
5
6
输入: "aab"
输出:
[
["aa","b"],
["a","a","b"]
]

题目分析

递归结构

$r(n)=r(n-1)+w$
其中w为任选的一个长度小于当前剩余长度的子串,并且为回文串

递归边界

1
2
3
if(n == len(s)){
return
}

递归参数

  • s
  • n
  • result
  • result_all

其他

其他方面还需要

  • 一个回文数判断的函数

答案

按照上面的思路,也就是最基本的回溯法的思路:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class Solution:

def __init__(self):
self.result_all = None

def partition(self, s: str) -> List[List[str]]:
self.result_all = []
self.dfs(s, 0, [])
return self.result_all

def dfs(self, s, n, result):
if len(s) == n:
self.result_all.append(list(result))
return
for i in range(n + 1, len(s) + 1):
sub_s = s[n: i]
if not self.palindrome_validate(sub_s):
continue
result.append(sub_s)
self.dfs(s, i, result)
result.pop()
return

def palindrome_validate(self, s):
return s == s[::-1]

这样大致提交结果如下:
!Alt text

可以发现空间上还是可以的,但是时间是比较慢的!

思路递进1:记忆化搜索

递归中一种最常用的剪枝的方法就是记忆化搜索,本题明显可以使用这种剪枝的方法。
整体写起来也是比较简单的,就是把所有字串是否是回文数缓存起来,这样递归的时候用起来就不用重复去判断了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
class Solution:

def __init__(self):
self.result_all = None
self.s_dict = None

def partition(self, s: str) -> List[List[str]]:
self.result_all = []
self.s_dict = {}
self.load_dict(s)
self.dfs(s, 0, [])
return self.result_all

def dfs(self, s, n, result):
if len(s) == n:
# print(result)
self.result_all.append(result[:])
return
for i in range(n + 1, len(s) + 1):
sub_s = s[n: i]
if not self.s_dict[sub_s]:
continue
result.append(sub_s)
self.dfs(s, i, result)
result.pop()
return

def load_dict(self, s):
for i in range(len(s)):
for j in range(i + 1, len(s) + 1):
sub_s = s[i:j]
if sub_s not in self.s_dict:
self.s_dict[sub_s] = self.palindrome_validate(sub_s)


def palindrome_validate(self, s):
return s==s[::-1]

可以明显看到这样子执行空间消耗虽然大了不少,但是时间明显还是提升了。
!Alt text

思路递进2:动态规划

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
class Solution:

def __init__(self):
self.result_all = None
self.s_dict = None

def partition(self, s: str) -> List[List[str]]:
# dp[i] = s[0...i]
# dp[i+1]={dp[i] + s[i+1], dp[i-1] * s[i..i+1], ... , dp[0] * s[1..i+1]}
dp = [[] for _ in range(len(s) + 1)]
for i in range(1, len(s)+1):
for j in range(i):
if self.palindrome_validate(s[j:i]):
if len(dp[j]) > 0:
for l in dp[j]:
dp[i].append(l + [s[j:i]])
else:
dp[i].append([s[j:i]])

return dp[-1]


def palindrome_validate(self, s):
return s==s[::-1]

!Alt text